import math
import numpy as np
import pandas as pd
import torch
import matplotlib
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
# Select the non-interactive Agg backend so figures render without a display.
# NOTE(review): this call comes after `import matplotlib.pyplot` above; older
# matplotlib releases ignore backend changes made after pyplot is imported —
# consider moving it above that import.
matplotlib.use('agg')


def plot_labels(y_probs, y_true, epoch, writer, tag):
    """Log a grid of per-sample class-probability bar charts to ``writer``.

    One subplot is drawn per row of ``y_probs``, eight subplots per grid row.
    Each subplot shows bars for classes 1..C and is annotated with the argmax
    prediction and, when ``y_true`` is given, the true label.

    Args:
        y_probs: 2-D tensor indexed as ``y_probs[sample, class]``; presumably
            probabilities, since the y-axis is fixed to [0, 1].
        y_true: per-sample true labels (indexable, elements support ``.item()``),
            or ``None`` to annotate predictions only.
        epoch: passed as ``global_step`` to ``writer.add_figure``.
        writer: object exposing ``add_figure(tag, fig, global_step=...)``
            (e.g. a TensorBoard ``SummaryWriter``).
        tag: tag under which the figure is logged.
    """
    n_cols = 8
    n_samples, n_classes = y_probs.size(0), y_probs.size(1)
    n_rows = math.ceil(n_samples / n_cols)

    # squeeze=False keeps ``axs`` 2-D even when there is only one grid row,
    # so ``axs[i, j]`` also works for batches of <= n_cols samples.
    fig, axs = plt.subplots(n_rows, n_cols, squeeze=False, figsize=(10, 5))

    y_preds = torch.argmax(y_probs, dim=1)
    # detach() so tensors that require grad can still be converted for plotting.
    probs = y_probs.detach().cpu().numpy()

    for n in range(n_samples):
        ax = axs[n // n_cols, n % n_cols]
        ax.bar(range(1, n_classes + 1), probs[n])
        if y_true is not None:
            caption = 'Pred {}, True {}'.format(y_preds[n].item(), y_true[n].item())
        else:
            caption = 'Pred {}'.format(y_preds[n].item())
        ax.text(n_classes / 2, .9, caption, fontsize=6, ha='center')
        ax.set_ylim([0, 1])

    # Hide x labels and tick labels for top plots and y ticks for right plots.
    for ax in axs.flat:
        ax.label_outer()
    fig.tight_layout(pad=1.0)

    writer.add_figure(tag, fig, global_step=epoch)


def make_tabular_dataframe(data, n_cat_of_categorical_features, n_continuous_features):
    """Build a DataFrame with continuous columns followed by one column per categorical feature.

    ``data`` is laid out as ``n_continuous_features`` continuous columns
    followed by consecutive blocks of width ``n_cat`` (one block per entry of
    ``n_cat_of_categorical_features``). Each block is collapsed to the
    position of its row-wise maximum via ``DataFrame.idxmax(axis=1)``.

    Continuous columns keep pandas' default integer labels ``0..n_cont-1``;
    categorical columns are labelled with the *strings* ``str(n_cont + k)``.
    """
    frame = pd.DataFrame(data[:, 0:n_continuous_features])
    start = n_continuous_features
    for offset, width in enumerate(n_cat_of_categorical_features):
        stop = start + width
        block = pd.DataFrame(data[:, start:stop])
        frame[str(n_continuous_features + offset)] = block.idxmax(axis=1)
        start = stop
    return frame


def plot_tabular_df(the_df, n_cat_of_categorical_features, n_continuous_features):
    """Draw one histogram per feature on a 4-column grid and return the figure.

    The first ``n_continuous_features`` columns of ``the_df`` are binned with
    the ``'sturges'`` rule. The following columns (one per entry of
    ``n_cat_of_categorical_features``) use edges ``0..n_cat``; presumably they
    hold integer category codes in ``[0, n_cat)``, giving one bin per category.
    Subplots are placed row-major: feature ``k`` goes to grid row ``k // 4``
    and column ``k % 4``.

    Raises:
        ValueError: if there are no features to plot.
    """
    n_features = n_continuous_features + len(n_cat_of_categorical_features)
    if n_features == 0:
        raise ValueError('plot_tabular_df needs at least one feature to plot')

    bins_list = ['sturges'] * n_continuous_features + [np.arange(x + 1) for x in n_cat_of_categorical_features]

    n_grid_cols = 4
    n_grid_rows = int(np.ceil(n_features / n_grid_cols))
    fig = plt.figure(figsize=(15, 10))
    gs = gridspec.GridSpec(n_grid_rows, n_grid_cols)
    # GridSpec integer indexing is row-major, so gs[k] is row k // 4, column k % 4.
    # Creating the axes directly in their cells avoids the single-feature case
    # where plt.subplots(1, 1) returns a bare Axes instead of an array.
    axes = [fig.add_subplot(gs[k]) for k in range(n_features)]

    for i, label in enumerate(the_df.columns):
        axes[i].hist(the_df[label], bins=bins_list[i])
    return fig


def generated_synthetic_dataframe(rep_model, n_samples):
    """Return ``rep_model.sample_dataframe(n_samples=n_samples)`` with autograd disabled."""
    with torch.no_grad():
        return rep_model.sample_dataframe(n_samples=n_samples)


def plot_histograms_of_reconstructed_data(encoder, decoder, data_loader, epoch, writer, opt, stage):
    """Reconstruct every batch of ``data_loader`` and log per-feature histograms.

    Each batch is encoded with ``encoder.posterior`` and decoded with
    ``decoder.model``. The reconstruction is the decoder's continuous mean
    followed by the argmax index of each categorical block, where the block
    sizes are given by ``decoder.n_cats``. All batches are stacked and plotted
    with ``plot_tabular_df``, and the figure is logged to ``writer`` under
    ``plot_features/stage{stage}_reconstructions``.

    Args:
        opt: config object read for ``data_join_task``, ``device``,
            ``ncat_of_cat_features`` and ``n_continuous_features``.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    data_join_task = opt.data_join_task
    device = opt.device

    reconstructions = []
    with torch.no_grad():
        for datapoint in data_loader:
            # Only the input features are used; the other items are ignored.
            if data_join_task:
                data, _, _ = datapoint
            else:
                data, _ = datapoint
            q_z_sample, _ = encoder.posterior(data.to(device))
            (p_x_mu, _), p_x_cat, _ = decoder.model(q_z_sample)
            # Collapse each categorical feature's output block to its argmax index.
            max_p_x_cat = torch.cat(
                [torch.argmax(p_feature, dim=1, keepdim=True)
                 for p_feature in torch.split(p_x_cat, decoder.n_cats, dim=1)],
                dim=1,
            )
            reconstruction = torch.cat([p_x_mu, max_p_x_cat.float()], dim=1)
            # Move each batch off the device now rather than accumulating on it.
            reconstructions.append(reconstruction.cpu())

    if not reconstructions:
        raise ValueError('data_loader yielded no batches to reconstruct')

    # A single concatenation avoids re-copying the accumulated tensor every batch.
    reconstructed_data = torch.cat(reconstructions, dim=0)
    reconstructed_df = pd.DataFrame(reconstructed_data.numpy())
    fig = plot_tabular_df(reconstructed_df, opt.ncat_of_cat_features, opt.n_continuous_features)
    writer.add_figure('plot_features/stage{}_reconstructions'.format(stage), fig, global_step=epoch)
